---
title: "SRCNN: Image super-resolution using deep convolutional networks"
keywords: fastai
sidebar: home_sidebar
---
{% raw %}
{% endraw %} {% raw %}
%reload_ext autoreload
%autoreload 2
%matplotlib inline
{% endraw %} {% raw %}
#import PIL
#from pathlib import PosixPath
{% endraw %} {% raw %}
{% endraw %} {% raw %}
import sys
sys.path.append('..')
from superres.datasets import *
from superres.databunch import *
{% endraw %} {% raw %}
# Seed the RNGs up front; the same `seed` is also passed to
# create_sr_databunch further down.
# NOTE(review): `random` and `np` are not imported explicitly in this file --
# presumably they arrive through the star imports above; confirm. No torch
# seed is set in this cell.
seed = 8610
random.seed(seed)
np.random.seed(seed)
{% endraw %}

Model

{% raw %}
{% endraw %} {% raw %}

class SRCNN[source]

SRCNN() :: Module

Image super-resolution using deep convolutional networks

{% endraw %}

DataBunch

{% raw %}
# Source of the training images. The databunch printed further down reports
# its path as .../DIV2K/DIV2K_train_HR_crop/256.
train_hr = div2k_train_hr_crop_256
{% endraw %} {% raw %}
# Settings forwarded to create_sr_databunch in the next cell.
# in_size / out_size: sizes for the x and y images; the databunch printout
#   shows both as (3, 256, 256).
# scale: super-resolution factor.
#   NOTE(review): presumably x is degraded by this factor and resized back up
#   to in_size -- confirm in superres.databunch.
# bs: passed as the databunch's `bs` argument (batch size).
in_size = 256
out_size = 256
scale = 4
bs = 10
{% endraw %} {% raw %}
# Build the databunch from the settings above, print its summary, and display
# a sample batch. Per the printout, x and y are both ImageImageLists.
data = create_sr_databunch(train_hr, in_size=in_size, out_size=out_size, scale=scale, bs=bs, seed=seed)
print(data)
data.show_batch()
ImageDataBunch;

Train: LabelList (25245 items)
x: ImageImageList
Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256)
y: ImageImageList
Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256)
Path: /home/jovyan/notebook/datasets/DIV2K/DIV2K_train_HR_crop/256;

Valid: LabelList (6311 items)
x: ImageImageList
Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256)
y: ImageImageList
Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256)
Path: /home/jovyan/notebook/datasets/DIV2K/DIV2K_train_HR_crop/256;

Test: None
{% endraw %}

Training

{% raw %}
# Assemble the learner: an SRCNN model optimised with a flattened MSE loss and
# monitored with the PSNR / SSIM metrics.
model = SRCNN()
# Class name of the model; reused as the checkpoint name and passed to sr_test.
model_name = type(model).__name__
metrics = [m_psnr, m_ssim]
loss_func = MSELossFlat()
learn = Learner(data, model, metrics=metrics, loss_func=loss_func)
{% endraw %} {% raw %}
# Run the learning-rate finder, then plot its recorded curve with the
# suggested learning rate marked (suggestion=True).
lr_find(learn)
learn.recorder.plot(suggestion=True)
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
Min numerical gradient: 7.59E-07
Min loss divided by 10: 1.58E-03
{% endraw %} {% raw %}
# Hyper-parameters for the fit_one_cycle call below.
# lr: learning rate; the LR finder output above reported "Min loss divided by
#   10: 1.58E-03".
# lrs: slice(lr) has only a stop value.
#   NOTE(review): with the single layer group shown in the Learner repr,
#   fastai should apply `lr` unchanged -- confirm against Learner.lr_range.
# epoch: number of epochs to train.
# pct_start, wd: forwarded unchanged to fit_one_cycle.
# save_fname: checkpoint name given to SaveModelCallback and later to
#   learn.load.
lr = 1e-3
lrs = slice(lr)
epoch = 3
pct_start = 0.3
wd = 1e-3
save_fname = model_name
{% endraw %} {% raw %}
# Callbacks passed to fit_one_cycle. SaveModelCallback stores a checkpoint
# under `save_fname`; the training log's "Better model found ... valid_loss"
# lines come from it.
# NOTE(review): ShowGraph presumably draws the loss curves during training --
# confirm in the fastai docs.
callbacks = [ShowGraph(learn), SaveModelCallback(learn, name=save_fname)]
{% endraw %} {% raw %}
# Train for `epoch` epochs with the one-cycle schedule, using the
# hyper-parameters and callbacks defined above.
learn.fit_one_cycle(epoch, lrs, pct_start=pct_start, wd=wd, callbacks=callbacks)
epoch train_loss valid_loss m_psnr m_ssim time
0 0.108567 0.089277 26.175465 0.395728 02:06
1 0.100952 0.063409 29.430597 0.430849 02:03
2 0.100369 0.059472 31.225027 0.439113 02:05
Better model found at epoch 0 with valid_loss value: 0.08927701413631439.
Better model found at epoch 1 with valid_loss value: 0.06340917944908142.
Better model found at epoch 2 with valid_loss value: 0.059471651911735535.
{% endraw %} {% raw %}
# Display sample results from the trained learner.
learn.show_results()
{% endraw %}

Test

{% raw %}
# Folder of images used for testing (the Set14 set, going by the name).
test_hr = set14_hr
{% endraw %} {% raw %}
# Test inputs (x) and targets (y) are both read from the same folder.
# x is opened at in_size with the same `scale` used for the training data
# (previously hard-coded as 4, which would silently diverge if `scale` were
# changed) and sizeup=True; y is opened at out_size only.
# NOTE(review): sizeup=True presumably resizes the degraded image back up to
# `size` -- confirm in after_open_image.
il_test_x = ImageImageList.from_folder(
    test_hr,
    after_open=partial(after_open_image, size=in_size, scale=scale, sizeup=True),
)
il_test_y = ImageImageList.from_folder(
    test_hr,
    after_open=partial(after_open_image, size=out_size),
)
{% endraw %} {% raw %}
# Reload the checkpoint saved under `save_fname` during training. The call is
# the last expression in the cell, so the notebook echoes the returned Learner.
learn.load(save_fname)
Learner(data=ImageDataBunch;

Train: LabelList (25245 items)
x: ImageImageList
Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256)
y: ImageImageList
Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256)
Path: /home/jovyan/notebook/datasets/DIV2K/DIV2K_train_HR_crop/256;

Valid: LabelList (6311 items)
x: ImageImageList
Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256)
y: ImageImageList
Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256)
Path: /home/jovyan/notebook/datasets/DIV2K/DIV2K_train_HR_crop/256;

Test: None, model=SRCNN(
  (conv1): Conv2d(3, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (conv2): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
  (conv3): Conv2d(32, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (activate): ReLU()
), opt_func=functools.partial(<class 'torch.optim.adam.Adam'>, betas=(0.9, 0.99)), loss_func=FlattenedLoss of MSELoss(), metrics=[<function m_psnr at 0x7f45e2a80598>, <function m_ssim at 0x7f45e2a80620>], true_wd=True, bn_wd=True, wd=0.01, train_bn=True, path=PosixPath('.'), model_dir='models', callback_fns=[functools.partial(<class 'fastai.basic_train.Recorder'>, add_time=True, silent=False)], callbacks=[], layer_groups=[Sequential(
  (0): Conv2d(3, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (1): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
  (2): Conv2d(32, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (3): ReLU()
)], add_time=True, silent=False)
{% endraw %} {% raw %}
# Evaluate on the test lists; the output reports PSNR and SSIM for a bicubic
# baseline and for the model named by `model_name`.
sr_test(learn, il_test_x, il_test_y, model_name)
bicubic: PSNR:24.11,SSIM:0.7822
SRCNN:	 PSNR:24.76,SSIM:0.8068
{% endraw %}

Report

{% raw %}
# Echo the model so the notebook prints its layer structure.
model
SRCNN(
  (conv1): Conv2d(3, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (conv2): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
  (conv3): Conv2d(32, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (activate): ReLU()
)
{% endraw %} {% raw %}
# Show the per-layer output shapes and parameter counts for the learner.
learn.summary()
SRCNN
======================================================================
Layer (type)         Output Shape         Param #    Trainable 
======================================================================
Conv2d               [64, 256, 256]       15,616     True      
______________________________________________________________________
Conv2d               [32, 256, 256]       2,080      True      
______________________________________________________________________
Conv2d               [3, 256, 256]        2,403      True      
______________________________________________________________________
ReLU                 [32, 256, 256]       0          False     
______________________________________________________________________

Total params: 20,099
Total trainable params: 20,099
Total non-trainable params: 0
Optimized with 'torch.optim.adam.Adam', betas=(0.9, 0.99)
Using true weight decay as discussed in https://www.fast.ai/2018/07/02/adam-weight-decay/ 
Loss function : FlattenedLoss
======================================================================
Callbacks functions applied 
{% endraw %}